Select
根据条件张量逐元素选择输入值。对于每个输出位置,如果条件为真(True),则选择 input0 的值;否则选择 input1 的值。该算子支持广播机制。
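计算公式为:
$$\text{output}[i] = \begin{cases} \text{input0}[idx_2], & \text{condition}[idx_1] = \text{True} \\ \text{input1}[idx_3], & \text{condition}[idx_1] = \text{False} \end{cases}$$
其中,当不需要广播时(is_broadcast = 0),idx1 = idx2 = idx3 = i;当需要广播时(is_broadcast = 1),idx1 = index_list1[i]、idx2 = index_list2[i]、idx3 = index_list3[i],即使用索引映射 index_list1、index_list2、index_list3 来确定各个输入张量的索引。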
- 输入:
input0 - 第一个输入数据地址。当条件为真时选择此值。
input1 - 第二个输入数据地址。当条件为假时选择此值。
condition - 条件数据地址(bool类型)。决定选择哪个输入的值。
- params - 其他参数按以下顺序打包成 long long 数组(指针类参数以地址值的形式存入)。
params[0]:output_dims - 输出张量的维度信息数组。
params[1]:output_dims_num - 输出张量的维度数。
params[2]:index_list1 - 条件张量的索引映射数组,用于广播场景。大小为输出总元素数。
params[3]:index_list2 - input0 的索引映射数组,用于广播场景。大小为输出总元素数。
params[4]:index_list3 - input1 的索引映射数组,用于广播场景。大小为输出总元素数。
params[5]:is_broadcast - 是否需要广播的标志。0 表示不需要广播,1 表示需要广播。
core_mask - 核掩码(仅共享存储版本需要)。不打包进 params,而是作为函数的最后一个独立参数传入。
- 输出:
output - 输出数据地址,其形状由 output_dims 和 output_dims_num 确定。
- 支持平台:
FT78NE、MT7004
备注
FT78NE 支持fp32, int8, int16, int32, fp64, cplx64, cplx128
MT7004 支持fp16, fp32, int16, int32, cplx64
共享存储版本:
-
void i8_select_s(int8_t *input0, int8_t *input1, bool *condition, int8_t *output, long long *params, int core_mask)
-
void i16_select_s(int16_t *input0, int16_t *input1, bool *condition, int16_t *output, long long *params, int core_mask)
-
void i32_select_s(int32_t *input0, int32_t *input1, bool *condition, int32_t *output, long long *params, int core_mask)
-
void hp_select_s(half *input0, half *input1, bool *condition, half *output, long long *params, int core_mask)
-
void fp_select_s(float *input0, float *input1, bool *condition, float *output, long long *params, int core_mask)
-
void dp_select_s(double *input0, double *input1, bool *condition, double *output, long long *params, int core_mask)
-
void c64_select_s(float *input0, float *input1, bool *condition, float *output, long long *params, int core_mask)
-
void c128_select_s(double *input0, double *input1, bool *condition, double *output, long long *params, int core_mask)
C调用示例(无广播):
1//FT78NE示例
2#include <stdio.h>
3#include <select.h>
4
5int main(int argc, char* argv[]) {
6 // 假设在DDR空间
7 float *input0 = (float *)0xA0000000;
8 float *input1 = (float *)0xA1000000;
9 bool *condition = (bool *)0xA2000000;
10 float *output = (float *)0xB0000000;
11
12 // 输出形状 [2, 3, 4]
13 unsigned long long output_dims[] = {2, 3, 4};
14 unsigned long long output_dims_num = 3;
15
16 // 计算总元素数
17 unsigned long long total_elements = 2 * 3 * 4; // 24
18
19 // 索引映射数组(无广播时可以为NULL或与输出索引相同)
20 unsigned long long *index_list1 = (unsigned long long *)0xC0000000;
21 unsigned long long *index_list2 = (unsigned long long *)0xC0100000;
22 unsigned long long *index_list3 = (unsigned long long *)0xC0200000;
23
24 // 初始化索引映射(无广播时直接使用顺序索引)
25 for (unsigned long long i = 0; i < total_elements; i++) {
26 index_list1[i] = i;
27 index_list2[i] = i;
28 index_list3[i] = i;
29 }
30
31 long long is_broadcast = 0; // 不需要广播
32 int core_mask = 0xff;
33
34 fp_select_s(input0, input1, condition, output, output_dims, output_dims_num,
35 index_list1, index_list2, index_list3, is_broadcast, core_mask);
36 return 0;
37}
C调用示例(有广播):
1//FT78NE示例
2#include <stdio.h>
3#include <select.h>
4
5int main(int argc, char* argv[]) {
6 float *input0 = (float *)0x81000000;
7 float *input1 = (float *)0x82000000;
8 bool *condition = (bool *)0x83000000;
9 float *output = (float *)0x84000000;
10 float *checkoutput = (float *)0x85000000;
11 unsigned long long *index_list1 = (unsigned long long *)0x86000000;
12 unsigned long long *index_list2 = (unsigned long long *)0x87000000;
13 unsigned long long *index_list3 = (unsigned long long *)0x87800000;
14
15 long long is_broadcast = 0;
16
17 unsigned long long *input0_dims = global_input0_dims;
18 unsigned long long *input1_dims = global_input1_dims;
19 unsigned long long *cond_dims= global_cond_dims;
20 unsigned long long *output_dims = global_output_dims;
21 unsigned long long input0_dims_num = global_input0_dims_num;
22 unsigned long long input1_dims_num = global_input1_dims_num;
23 unsigned long long cond_dims_num = global_cond_dims_num;
24 unsigned long long output_dims_num = global_output_dims_num;
25
26 unsigned long long params[10];
27
28 params[0] = (unsigned long long)output_dims;
29 params[2] = (unsigned long long)index_list1;
30 params[3] = (unsigned long long)index_list2;
31 params[4] = (unsigned long long)index_list3;
32
33 //先计算is_broadcast
34 unsigned long long input0_num = get_total_elements(input0_dims_num, input0_dims);
35 unsigned long long input1_num = get_total_elements(input1_dims_num, input1_dims);
36 unsigned long long cond_num = get_total_elements(cond_dims_num, cond_dims);
37 unsigned long long output_num = get_total_elements(output_dims_num, output_dims);
38
39 if((input0_num == output_num && input1_num == output_num) && (cond_num == output_num)) {
40 is_broadcast = 0;
41 } else {
42 is_broadcast = 1;
43 }
44
45 params[1] = (unsigned long long)output_dims_num;
46 params[5] = (unsigned long long)is_broadcast;
47
48 srand(seed++);
49 int i;
50
51 //初始化input0, input1, condition
52 for (i = 0; i < input0_num; ++i) {
53 input0[i] = (float)(rand() % 100) / 10.0f;
54 }
55
56 for (i = 0; i < input1_num; ++i) {
57 input1[i] = (float)(rand() % 100) / 10.0f;
58 }
59
60 for (i = 0; i < cond_num; ++i) {
61 condition[i] = (bool)(rand() % 2);
62 }
63 int core_mask = 0x0f;
64 if(is_broadcast) {
65 GetBroadCastIndex(cond_dims, cond_dims_num, output_dims, output_dims_num, index_list1);
66 GetBroadCastIndex(input0_dims, input0_dims_num, output_dims, output_dims_num, index_list2);
67 GetBroadCastIndex(input1_dims, input1_dims_num, output_dims, output_dims_num, index_list3);
68
69 fp_select_s(input0, input1, condition, output, params, core_mask);
70
71 } else {
72 fp_select_s(input0, input1, condition, output, params, core_mask);
73 }
74 return 0;
75}
私有存储版本:
-
void i8_select_p(int8_t *input0, int8_t *input1, bool *condition, int8_t *output, long long *params)
-
void i16_select_p(int16_t *input0, int16_t *input1, bool *condition, int16_t *output, long long *params)
-
void i32_select_p(int32_t *input0, int32_t *input1, bool *condition, int32_t *output, long long *params)
-
void hp_select_p(half *input0, half *input1, bool *condition, half *output, long long *params)
-
void fp_select_p(float *input0, float *input1, bool *condition, float *output, long long *params)
-
void dp_select_p(double *input0, double *input1, bool *condition, double *output, long long *params)
-
void c64_select_p(float *input0, float *input1, bool *condition, float *output, long long *params)
-
void c128_select_p(double *input0, double *input1, bool *condition, double *output, long long *params)
C调用示例(私有存储版本):
1//FT78NE示例
2#include <stdio.h>
3#include <select.h>
4
5int main(int argc, char* argv[]) {
6 float *input0 = (float *)0x10010000;
7 float *input1 = (float *)0x10020000;
8 bool *condition = (bool *)0x10030000;
9 float *output = (float *)0x10040000;
10 float *checkoutput = (float *)0x10050000;
11 unsigned long long *index_list1 = (unsigned long long *)0x10060000;
12 unsigned long long *index_list2 = (unsigned long long *)0x10070000;
13 unsigned long long *index_list3 = (unsigned long long *)0x10078000;
14
15 long long is_broadcast = 0;
16
17 unsigned long long *input0_dims = global_input0_dims;
18 unsigned long long *input1_dims = global_input1_dims;
19 unsigned long long *cond_dims= global_cond_dims;
20 unsigned long long *output_dims = global_output_dims;
21 unsigned long long input0_dims_num = global_input0_dims_num;
22 unsigned long long input1_dims_num = global_input1_dims_num;
23 unsigned long long cond_dims_num = global_cond_dims_num;
24 unsigned long long output_dims_num = global_output_dims_num;
25
26 unsigned long long params[10];
27
28 params[0] = (unsigned long long)output_dims;
29 params[2] = (unsigned long long)index_list1;
30 params[3] = (unsigned long long)index_list2;
31 params[4] = (unsigned long long)index_list3;
32
33 //先计算is_broadcast
34 unsigned long long input0_num = get_total_elements(input0_dims_num, input0_dims);
35 unsigned long long input1_num = get_total_elements(input1_dims_num, input1_dims);
36 unsigned long long cond_num = get_total_elements(cond_dims_num, cond_dims);
37 unsigned long long output_num = get_total_elements(output_dims_num, output_dims);
38
39 if((input0_num == output_num && input1_num == output_num) && (cond_num == output_num)) {
40 is_broadcast = 0;
41 } else {
42 is_broadcast = 1;
43 }
44
45 params[1] = (unsigned long long)output_dims_num;
46 params[5] = (unsigned long long)is_broadcast;
47
48 srand(seed++);
49 int i;
50
51 //初始化input0, input1, condition
52 for (i = 0; i < input0_num; ++i) {
53 input0[i] = (float)(rand() % 100) / 10.0f;
54 }
55
56 for (i = 0; i < input1_num; ++i) {
57 input1[i] = (float)(rand() % 100) / 10.0f;
58 }
59
60 for (i = 0; i < cond_num; ++i) {
61 condition[i] = (bool)(rand() % 2);
62 }
63
64 if(is_broadcast) {
65 GetBroadCastIndex(cond_dims, cond_dims_num, output_dims, output_dims_num, index_list1);
66 GetBroadCastIndex(input0_dims, input0_dims_num, output_dims, output_dims_num, index_list2);
67 GetBroadCastIndex(input1_dims, input1_dims_num, output_dims, output_dims_num, index_list3);
68
69 fp_select_p(input0, input1, condition, output, params);
70
71 } else {
72 fp_select_p(input0, input1, condition, output, params);
73 }
74
75 return 0;
76}